{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L_R1KOeEr1OM"
      },
      "source": [
        "\u003cp align=\"center\"\u003e\n",
        "  \u003ch1 align=\"center\"\u003eTAPIR: Tracking Any Point with per-frame Initialization and temporal Refinement\u003c/h1\u003e\n",
        "  \u003cp align=\"center\"\u003e\n",
        "    \u003ca href=\"http://www.carldoersch.com/\"\u003eCarl Doersch\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://yangyi02.github.io/\"\u003eYi Yang\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://scholar.google.com/citations?user=Jvi_XPAAAAAJ\"\u003eMel Vecerik\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://scholar.google.com/citations?user=cnbENAEAAAAJ\"\u003eDilara Gokay\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://www.robots.ox.ac.uk/~ankush/\"\u003eAnkush Gupta\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"http://people.csail.mit.edu/yusuf/\"\u003eYusuf Aytar\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ\"\u003eJoao Carreira\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://www.robots.ox.ac.uk/~az/\"\u003eAndrew Zisserman\u003c/a\u003e\n",
        "  \u003c/p\u003e\n",
        "  \u003ch3 align=\"center\"\u003e\u003ca href=\"https://arxiv.org/abs/2306.08637\"\u003ePaper\u003c/a\u003e | \u003ca href=\"https://deepmind-tapir.github.io\"\u003eProject Page\u003c/a\u003e | \u003ca href=\"https://github.com/deepmind/tapnet\"\u003eGitHub\u003c/a\u003e | \u003ca href=\"https://github.com/deepmind/tapnet/tree/main#running-tapir-locally\"\u003eLive Demo\u003c/a\u003e \u003c/h3\u003e\n",
        "  \u003cdiv align=\"center\"\u003e\u003c/div\u003e\n",
        "\u003c/p\u003e\n",
        "\n",
        "\u003cp align=\"center\"\u003e\n",
        "  \u003cimg src=\"https://storage.googleapis.com/dm-tapnet/horsejump_rainbow.gif\" width=\"70%\"/\u003e\u003cbr/\u003e\u003cbr/\u003e\n",
        "\u003c/p\u003e\n",
        "\u003cp\u003e\n",
        "  This visualization uses TAPIR to show how an object moves through space, even if the camera is tracking the object.  It begins by tracking points densely on a grid.  Then it estimates the camera motion as a homography (i.e., assuming either a planar background or a camera that rotates but does not translate).  Points that move consistently with that homography are treated as background and removed.  Finally, it generates a \u0026ldquo;rainbow\u0026rdquo; visualization, where the remaining tracked points leave \u0026ldquo;tails\u0026rdquo; that follow the camera motion, so it looks like the earlier positions of points are frozen in space.  This visualization was inspired by a similar one from \u003ca href=\"https://omnimotion.github.io/\"\u003eOmniMotion\u003c/a\u003e, although that one assumes ground-truth segmentations are available and models the camera motion as 2D translation only.\n",
        "\u003c/p\u003e\n",
        "\u003cp\u003e\n",
        "  Note that we consider this algorithm \u0026ldquo;semi-automatic\u0026rdquo; because it may need some tuning to give pleasing results on arbitrary videos.  Tracking failures on the background may show up as foreground objects.  Results are sensitive to the outlier thresholds used in RANSAC and segmentation, and you may wish to discard short tracks.  You can sample points in a different way (e.g., from multiple frames) and everything will still work, but the \u003cfont face=\"Courier\"\u003eplot_tracks_tails\u003c/font\u003e function uses the input order of the points to choose colors, so you will have to sort the points appropriately.\n",
        "\u003c/p\u003e\n",
        "\n",
        ""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6yflCqOMaDJP"
      },
      "outputs": [],
      "source": [
        "# @title Install code and dependencies {form-width: \"25%\"}\n",
        "!pip install git+https://github.com/google-deepmind/tapnet.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pmiHQJx4mvb1"
      },
      "outputs": [],
      "source": [
        "MODEL_TYPE = 'bootstapir'  # 'tapir' or 'bootstapir'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vnG2QEbxaH5Q"
      },
      "outputs": [],
      "source": [
        "# @title Download Model {form-width: \"25%\"}\n",
        "\n",
        "%mkdir -p tapnet/checkpoints\n",
        "\n",
        "if MODEL_TYPE == \"tapir\":\n",
        "  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy\n",
        "else:\n",
        "  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.npy\n",
        "\n",
        "%ls tapnet/checkpoints"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NOHqUSIsmmd0"
      },
      "outputs": [],
      "source": [
        "# @title Imports {form-width: \"25%\"}\n",
        "\n",
        "import haiku as hk\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import mediapy as media\n",
        "import numpy as np\n",
        "import tree"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OK7F5PHdsBZ0"
      },
      "outputs": [],
      "source": [
        "# @title Load Checkpoint {form-width: \"25%\"}\n",
        "\n",
        "from tapnet.models import tapir_model\n",
        "from tapnet.utils import model_utils\n",
        "from tapnet.utils import transforms\n",
        "from tapnet.utils import viz_utils\n",
        "\n",
        "if MODEL_TYPE == 'tapir':\n",
        "  checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'\n",
        "else:\n",
        "  checkpoint_path = 'tapnet/checkpoints/bootstapir_checkpoint_v2.npy'\n",
        "ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()\n",
        "params, state = ckpt_state['params'], ckpt_state['state']\n",
        "\n",
        "kwargs = dict(bilinear_interp_with_depthwise_conv=False, pyramid_level=0)\n",
        "if MODEL_TYPE == 'bootstapir':\n",
        "  kwargs.update(\n",
        "      dict(pyramid_level=1, extra_convs=True, softmax_temperature=10.0)\n",
        "  )\n",
        "tapir = tapir_model.ParameterizedTAPIR(params, state, tapir_kwargs=kwargs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "64dc5wE7KkC-"
      },
      "source": [
        "## Load and Build Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K2daaUz12gnE"
      },
      "outputs": [],
      "source": [
        "# @title Utilities for model inference {form-width: \"25%\"}\n",
        "\n",
        "\n",
        "def sample_grid_points(frame_idx, height, width, stride=1):\n",
        "  \"\"\"Sample grid points with (time, height, width) order.\"\"\"\n",
        "  points = np.mgrid[stride // 2 : height : stride, stride // 2 : width : stride]\n",
        "  points = points.transpose(1, 2, 0)\n",
        "  out_height, out_width = points.shape[0:2]\n",
        "  frame_idx = np.ones((out_height, out_width, 1)) * frame_idx\n",
        "  points = np.concatenate((frame_idx, points), axis=-1).astype(np.int32)\n",
        "  points = points.reshape(-1, 3)  # [out_height*out_width, 3]\n",
        "  return points"
      ]
    },
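    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check of the query format, the next cell calls `sample_grid_points` (defined in the previous cell) on a tiny grid. The 4x6 frame size, the stride of 2, and the `toy_points` name are illustrative assumptions and are not used elsewhere in this notebook. Each row should come out in `(t, y, x)` order."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative only: a hypothetical 4x6 frame sampled with stride 2.\n",
        "# Rows start at y=1 and columns at x=1 because of the stride // 2 offset.\n",
        "toy_points = sample_grid_points(frame_idx=0, height=4, width=6, stride=2)\n",
        "print(toy_points.shape)  # (6, 3): two rows by three columns of points.\n",
        "print(toy_points[:3])  # Each row is (t, y, x)."
      ]
    },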
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dTeDYLaRE2zs"
      },
      "outputs": [],
      "source": [
        "# @title Load an Exemplar Video {form-width: \"25%\"}\n",
        "\n",
        "%mkdir -p tapnet/examplar_videos\n",
        "\n",
        "!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4\n",
        "\n",
        "orig_frames = media.read_video(\"tapnet/examplar_videos/horsejump-high.mp4\")\n",
        "height, width = orig_frames.shape[1:3]\n",
        "media.show_video(orig_frames, fps=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TNn1sLNaeST8"
      },
      "outputs": [],
      "source": [
        "# @title Inference function {form-width: \"25%\"}\n",
        "\n",
        "resize_height = 256  # @param {type: \"integer\"}\n",
        "resize_width = 256  # @param {type: \"integer\"}\n",
        "stride = 8  # @param {type: \"integer\"}\n",
        "query_frame = 0  # @param {type: \"integer\"}\n",
        "\n",
        "frames = media.resize_video(orig_frames, (resize_height, resize_width))\n",
        "frames = model_utils.preprocess_frames(frames[None])\n",
        "feature_grids = tapir.get_feature_grids(frames, is_training=False)\n",
        "chunk_size = 64\n",
        "height, width = orig_frames.shape[1:3]\n",
        "\n",
        "\n",
        "def chunk_inference(query_points):\n",
        "  query_points = query_points.astype(np.float32)[None]\n",
        "\n",
        "  outputs = tapir(\n",
        "      video=frames,\n",
        "      query_points=query_points,\n",
        "      is_training=False,\n",
        "      query_chunk_size=chunk_size,\n",
        "      feature_grids=feature_grids,\n",
        "  )\n",
        "  tracks, occlusions, expected_dist = (\n",
        "      outputs[\"tracks\"],\n",
        "      outputs[\"occlusion\"],\n",
        "      outputs[\"expected_dist\"],\n",
        "  )\n",
        "\n",
        "  # Binarize occlusions\n",
        "  visibles = model_utils.postprocess_occlusions(occlusions, expected_dist)\n",
        "  return tracks[0], visibles[0]\n",
        "\n",
        "\n",
        "chunk_inference = jax.jit(chunk_inference)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_5A-j_PWUFmA"
      },
      "outputs": [],
      "source": [
        "# @title Predict semi-dense point tracks {form-width: \"25%\"}\n",
        "%%time\n",
        "\n",
        "\n",
        "query_points = sample_grid_points(\n",
        "    query_frame, resize_height, resize_width, stride\n",
        ")\n",
        "\n",
        "tracks = []\n",
        "visibles = []\n",
        "for i in range(0, query_points.shape[0], chunk_size):\n",
        "  query_points_chunk = query_points[i : i + chunk_size]\n",
        "  num_extra = chunk_size - query_points_chunk.shape[0]\n",
        "  if num_extra \u003e 0:\n",
        "    query_points_chunk = np.concatenate(\n",
        "        [query_points_chunk, np.zeros([num_extra, 3])], axis=0\n",
        "    )\n",
        "  tracks2, visibles2 = chunk_inference(query_points_chunk)\n",
        "  if num_extra \u003e 0:\n",
        "    tracks2 = tracks2[:-num_extra]\n",
        "    visibles2 = visibles2[:-num_extra]\n",
        "  tracks.append(tracks2)\n",
        "  visibles.append(visibles2)\n",
        "tracks = jnp.concatenate(tracks, axis=0)\n",
        "visibles = jnp.concatenate(visibles, axis=0)\n",
        "\n",
        "tracks = transforms.convert_grid_coordinates(\n",
        "    tracks, (resize_width, resize_height), (width, height)\n",
        ")\n",
        "\n",
        "# We show the point tracks without rainbows so you can see the input.\n",
        "video = viz_utils.plot_tracks_v2(orig_frames, tracks, 1.0 - visibles)\n",
        "media.show_video(video, fps=10)"
      ]
    },
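    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the next step concrete, here is a minimal sketch of how a homography can separate background points from foreground points. It is not the implementation in `viz_utils`: the helper `apply_homography`, the `toy_` variables, and the example homography are all assumptions made up for illustration, and the 0.07 threshold simply mirrors the default in the following cell. Points whose motion matches the homography have near-zero error, while a point that also moves on its own exceeds the threshold."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def apply_homography(homog, points_xy):\n",
        "  \"\"\"Maps [N, 2] points through a 3x3 homography (hypothetical helper).\"\"\"\n",
        "  ones = np.ones((points_xy.shape[0], 1))\n",
        "  mapped = np.concatenate([points_xy, ones], axis=-1) @ homog.T\n",
        "  return mapped[:, :2] / mapped[:, 2:3]\n",
        "\n",
        "\n",
        "# Assumed camera motion: a pure shift of 0.1 in x, in normalized coordinates.\n",
        "toy_homog = np.array([[1.0, 0.0, 0.1], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])\n",
        "toy_start = np.array([[0.2, 0.2], [0.5, 0.5], [0.8, 0.3]])\n",
        "\n",
        "# All points follow the camera; the second one also moves by 0.2 in y.\n",
        "toy_end = apply_homography(toy_homog, toy_start)\n",
        "toy_end[1] += np.array([0.0, 0.2])\n",
        "\n",
        "# Squared distance between observed and homography-predicted positions.\n",
        "toy_pred = apply_homography(toy_homog, toy_start)\n",
        "toy_err = np.sum(np.square(toy_pred - toy_end), axis=-1)\n",
        "toy_is_fg = toy_err \u003e= np.square(0.07)\n",
        "print(toy_is_fg)  # Only the second point departs from the camera motion."
      ]
    },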
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vyl_hSxaJFsz"
      },
      "outputs": [],
      "source": [
        "# The inlier point threshold for RANSAC, specified in normalized coordinates\n",
        "# (points are rescaled to the range [0, 1] for optimization).\n",
        "ransac_inlier_threshold = 0.07  # @param {type: \"number\"}\n",
        "# What fraction of a trajectory's points must be inliers for RANSAC to\n",
        "# consider that trajectory trustworthy for estimating the homography.\n",
        "ransac_track_inlier_frac = 0.95  # @param {type: \"number\"}\n",
        "# After initial RANSAC, how many refinement passes to adjust the homographies\n",
        "# based on tracks that have been deemed trustworthy.\n",
        "num_refinement_passes = 2  # @param {type: \"integer\"}\n",
        "# After homographies are estimated, consider points to be outliers if they are\n",
        "# farther than this threshold from the homography's prediction.\n",
        "foreground_inlier_threshold = 0.07  # @param {type: \"number\"}\n",
        "# After homographies are estimated, consider a track to be part of the\n",
        "# foreground if at most this fraction of its visible points are inliers.\n",
        "foreground_frac = 0.6  # @param {type: \"number\"}\n",
        "\n",
        "\n",
        "occluded = 1.0 - visibles\n",
        "homogs, err, canonical = viz_utils.get_homographies_wrt_frame(\n",
        "    tracks,\n",
        "    occluded,\n",
        "    [width, height],\n",
        "    thresh=ransac_inlier_threshold,\n",
        "    outlier_point_threshold=ransac_track_inlier_frac,\n",
        "    num_refinement_passes=num_refinement_passes,\n",
        ")\n",
        "\n",
        "inliers = (err \u003c np.square(foreground_inlier_threshold)) * visibles\n",
        "inlier_ct = np.sum(inliers, axis=-1)\n",
        "ratio = inlier_ct / np.maximum(1.0, np.sum(visibles, axis=1))\n",
        "is_fg = ratio \u003c= foreground_frac\n",
        "video = viz_utils.plot_tracks_tails(\n",
        "    orig_frames, tracks[is_fg], occluded[is_fg], homogs\n",
        ")\n",
        "media.show_video(video, fps=12)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
